{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch version: 2.2.2\n",
      "tiktoken version: 0.7.0\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "import tiktoken\n",
    "import torch\n",
    "\n",
    "print(\"torch version:\", version(\"torch\"))\n",
    "print(\"tiktoken version:\", version(\"tiktoken\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "\n",
    "if not os.path.exists(\"the-verdict.txt\"):\n",
    "    url = (\"https://raw.githubusercontent.com/rasbt/\"\n",
    "           \"LLMs-from-scratch/main/ch02/01-main-chapter-code/\"\n",
    "           \"the-verdict.txt\")\n",
    "    file_path = \"the-verdict.txt\"\n",
    "    urllib.request.urlretrieve(url, file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of cha: 20479\n",
      "I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no \n"
     ]
    }
   ],
   "source": [
    "with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
    "    raw_text = f.read()\n",
    "\n",
    "print(\"Total number of cha:\", len(raw_text))\n",
    "print(raw_text[:99])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Hello,', ' ', 'world.', ' ', 'This,', ' ', 'is', ' ', 'a', ' ', 'test.']\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "text = \"Hello, world. This, is a test.\"\n",
    "result = re.split(r'(\\s)', text)\n",
    "\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Hello', ',', '', ' ', 'world', '.', '', ' ', 'This', ',', '', ' ', 'is', ' ', 'a', ' ', 'test', '.', '']\n"
     ]
    }
   ],
   "source": [
    "result = re.split(r'([,.]|\\s)', text)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Hello', ',', 'world', '.', 'This', ',', 'is', 'a', 'test', '.']\n"
     ]
    }
   ],
   "source": [
    "result = [item for item in result if item.strip()]\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Hello', ',', '', ' ', 'world', '.', '', ' ', 'Is', ' ', 'this', '--', '', ' ', 'a', ' ', 'test', '?', '']\n"
     ]
    }
   ],
   "source": [
    "text = \"Hello, world. Is this-- a test?\"\n",
    "result = re.split(r'([,.:;?_!\"()\\']|--|\\s)', text)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['I', 'HAD', 'always', 'thought', 'Jack', 'Gisburn', 'rather', 'a', 'cheap', 'genius', '--', 'though', 'a', 'good', 'fellow', 'enough', '--', 'so', 'it', 'was', 'no', 'great', 'surprise', 'to', 'me', 'to', 'hear', 'that', ',', 'in']\n"
     ]
    }
   ],
   "source": [
    "preprocessed = re.split(r'([,.:;?_!\"()\\']|--|\\s)',raw_text)\n",
    "preprocessed = [item.strip() for item in preprocessed if item.strip()]\n",
    "print(preprocessed[:30])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4690\n"
     ]
    }
   ],
   "source": [
    "\n",
    "print(len(preprocessed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1130\n",
      "['!', '\"', \"'\", '(', ')', ',', '--', '.', ':', ';', '?', 'A', 'Ah', 'Among', 'And', 'Are', 'Arrt', 'As', 'At', 'Be', 'Begin', 'Burlington', 'But', 'By', 'Carlo', 'Chicago', 'Claude', 'Come', 'Croft', 'Destroyed', 'Devonshire', 'Don', 'Dubarry', 'Emperors', 'Florence', 'For', 'Gallery', 'Gideon', 'Gisburn', 'Gisburns', 'Grafton', 'Greek', 'Grindle', 'Grindles', 'HAD', 'Had', 'Hang', 'Has', 'He', 'Her']\n"
     ]
    }
   ],
   "source": [
    "all_words = sorted(set(preprocessed))\n",
    "vocab_size = len(all_words)\n",
    "\n",
    "print(vocab_size)\n",
    "print(all_words[:50])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('!', 0)\n",
      "('\"', 1)\n",
      "(\"'\", 2)\n",
      "('(', 3)\n",
      "(')', 4)\n",
      "(',', 5)\n",
      "('--', 6)\n",
      "('.', 7)\n",
      "(':', 8)\n",
      "(';', 9)\n",
      "('?', 10)\n",
      "('A', 11)\n",
      "('Ah', 12)\n",
      "('Among', 13)\n",
      "('And', 14)\n",
      "('Are', 15)\n",
      "('Arrt', 16)\n",
      "('As', 17)\n",
      "('At', 18)\n",
      "('Be', 19)\n",
      "('Begin', 20)\n",
      "('Burlington', 21)\n",
      "('But', 22)\n",
      "('By', 23)\n",
      "('Carlo', 24)\n",
      "('Chicago', 25)\n",
      "('Claude', 26)\n",
      "('Come', 27)\n",
      "('Croft', 28)\n",
      "('Destroyed', 29)\n",
      "('Devonshire', 30)\n",
      "('Don', 31)\n",
      "('Dubarry', 32)\n",
      "('Emperors', 33)\n",
      "('Florence', 34)\n",
      "('For', 35)\n",
      "('Gallery', 36)\n",
      "('Gideon', 37)\n",
      "('Gisburn', 38)\n",
      "('Gisburns', 39)\n",
      "('Grafton', 40)\n",
      "('Greek', 41)\n",
      "('Grindle', 42)\n",
      "('Grindles', 43)\n",
      "('HAD', 44)\n",
      "('Had', 45)\n",
      "('Hang', 46)\n",
      "('Has', 47)\n",
      "('He', 48)\n",
      "('Her', 49)\n",
      "('Hermia', 50)\n"
     ]
    }
   ],
   "source": [
    "vocab = {token:integer for integer,token in enumerate(all_words)}\n",
    "\n",
    "for i, item in enumerate(vocab.items()):\n",
    "    print(item)\n",
    "    if i>= 50:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleTokenizerV1:\n",
    "    def __init__(self, vocab):\n",
    "        self.str_to_int = vocab\n",
    "        self.int_to_str = {i:s for s,i in vocab.items()}\n",
    "    \n",
    "    def encode(self, text):\n",
    "        preprocessed = re.split(r'([,.?_!\"()\\']|--|\\s)', text)\n",
    "        preprocessed = [item.strip() for item in preprocessed if item.strip()]\n",
    "        ids = [self.str_to_int[s] for s in preprocessed]\n",
    "        return ids\n",
    "    \n",
    "    def decode(self, ids):\n",
    "        text = \" \".join([self.int_to_str[i] for i in ids])\n",
    "        text = re.sub(r'\\s+([,.?!\"()\\'])', r'\\1', text)\n",
    "        return text\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 56, 2, 850, 988, 602, 533, 746, 5, 1126, 596, 5, 1, 67, 7, 38, 851, 1108, 754, 793, 7]\n"
     ]
    }
   ],
   "source": [
    "tokenizer = SimpleTokenizerV1(vocab)\n",
    "\n",
    "text=\"\"\"\"It's the last he painted, you know,\"\n",
    "         Mrs. Gisburn said with pardonable pride.\"\"\"\n",
    "ids = tokenizer.encode(text)\n",
    "print(ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\" It\\' s the last he painted, you know,\" Mrs. Gisburn said with pardonable pride.'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1132"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_token = sorted(list(set(preprocessed)))\n",
    "all_token.extend([\"<|endoftext|>\", \"<|unk|>\"])\n",
    "\n",
    "vocab = {token:integer for integer,token in enumerate(all_token)}\n",
    "len(vocab.items())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('younger', 1127)\n",
      "('your', 1128)\n",
      "('yourself', 1129)\n",
      "('<|endoftext|>', 1130)\n",
      "('<|unk|>', 1131)\n"
     ]
    }
   ],
   "source": [
    "for i,item in enumerate(list(vocab.items())[-5:]):\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleTokenizerV2:\n",
    "    def __init__(self, vocab):\n",
    "        self.str_to_int = vocab\n",
    "        self.int_to_str = {i:s for s,i in vocab.items()}\n",
    "\n",
    "    def encode(self, text):\n",
    "        preprocessed = re.split(r'([,.?_!\"()\\']|--|\\s)', text)\n",
    "        preprocessed = [item.strip() for item in preprocessed if item.strip()]\n",
    "        preprocessed = [\n",
    "            item if item in self.str_to_int\n",
    "            else \"<|unk|>\" for item in preprocessed\n",
    "        ]\n",
    "        ids = [self.str_to_int[s] for s in preprocessed]\n",
    "        return ids\n",
    "    \n",
    "    def decode(self, ids):\n",
    "        text = \" \".join([self.int_to_str[i] for i in ids])\n",
    "        text = re.sub(r'\\s+([,.?!\"()\\'])', r'\\1', text)\n",
    "        return text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello, do you like tea <|endoftext|> In the sunlit terraces of palace.\n"
     ]
    }
   ],
   "source": [
    "tokenizer = SimpleTokenizerV2(vocab)\n",
    "\n",
    "text1 = \"Hello, do you like tea\"\n",
    "text2 = \"In the sunlit terraces of palace.\"\n",
    "\n",
    "text = \" <|endoftext|> \".join((text1, text2))\n",
    "\n",
    "print(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1131, 5, 355, 1126, 628, 975, 1130, 55, 988, 956, 984, 722, 1131, 7]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.encode(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<|unk|>, do you like tea <|endoftext|> In the sunlit terraces of <|unk|>.'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenizer.encode(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tiktoken viersion: 0.7.0\n"
     ]
    }
   ],
   "source": [
    "import importlib\n",
    "import importlib.metadata\n",
    "import tiktoken\n",
    "\n",
    "print(\"tiktoken viersion:\", importlib.metadata.version(\"tiktoken\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[15496, 11, 466, 345, 588, 8887, 30, 554, 262, 4252, 18250, 1059, 430, 286, 617, 34680, 27271, 13]\n"
     ]
    }
   ],
   "source": [
    "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
    "\n",
    "text = (\n",
    "    \"Hello, do you like tea? In the sunlit terra of someunknownPlace.\"\n",
    ")\n",
    "integers = tokenizer.encode(text, allowed_special={\"<|endoftext|>\"})\n",
    "print(integers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello, do you like tea? In the sunlit terra of someunknownPlace.\n"
     ]
    }
   ],
   "source": [
    "strings = tokenizer.decode(integers)\n",
    "print(strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5145\n"
     ]
    }
   ],
   "source": [
    "with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
    "    raw_text = f.read()\n",
    "\n",
    "enc_text = tokenizer.encode(raw_text)\n",
    "print(len(enc_text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x:[290, 4920, 2241, 287]\n",
      "y:     [4920, 2241, 287, 257]\n"
     ]
    }
   ],
   "source": [
    "enc_sample = enc_text[50:]\n",
    "\n",
    "context_size = 4\n",
    "x = enc_sample[:context_size]\n",
    "y = enc_sample[1:context_size+1]\n",
    "\n",
    "print(f\"x:{x}\")\n",
    "print(f\"y:     {y}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[290] ----> 4920\n",
      "[290, 4920] ----> 2241\n",
      "[290, 4920, 2241] ----> 287\n",
      "[290, 4920, 2241, 287] ----> 257\n"
     ]
    }
   ],
   "source": [
    "for i in range(1, context_size+1):\n",
    "    context = enc_sample[:i]\n",
    "    desired = enc_sample[i]\n",
    "\n",
    "    print(context, \"---->\", desired)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " and --->  established\n",
      " and established --->  himself\n",
      " and established himself --->  in\n",
      " and established himself in --->  a\n"
     ]
    }
   ],
   "source": [
    "for i in range(1, context_size+1):\n",
    "    context = enc_sample[:i]\n",
    "    desired = enc_sample[i]\n",
    "    print(tokenizer.decode(context), \"--->\", tokenizer.decode([desired]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pytorch version: 2.2.2\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "print(\"Pytorch version:\", torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class GPTDatasetV1(Dataset):\n",
    "    def __init__(self, txt, tokenizer, max_length, stride):\n",
    "        self.input_ids = []\n",
    "        self.target_ids = []\n",
    "\n",
    "        token_ids = tokenizer.encode(txt, allowed_special={\"<|endoftext|>\"})\n",
    "\n",
    "        for i in range(0, len(token_ids)-max_length, stride):\n",
    "            input_chunk = token_ids[i:i+max_length]\n",
    "            target_chunk = token_ids[i+1:i+max_length+1]\n",
    "            self.input_ids.append(torch.tensor(input_chunk))\n",
    "            self.target_ids.append(torch.tensor(target_chunk))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.input_ids)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        return self.input_ids[idx], self.target_ids[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataloader_v1(txt, batch_size=4, max_length=256,stride=128,shuffle=True,drop_last=True,num_workers=0):\n",
    "    tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
    "\n",
    "    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
    "\n",
    "    dataloader = DataLoader(\n",
    "        dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=shuffle,\n",
    "        drop_last=drop_last,\n",
    "        num_workers=num_workers\n",
    "    )\n",
    "\n",
    "    return dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"the-verdict.txt\",\"r\",encoding=\"utf-8\") as f:\n",
    "    raw_text = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([[  40,  367, 2885, 1464]]), tensor([[ 367, 2885, 1464, 1807]])]\n"
     ]
    }
   ],
   "source": [
    "dataloader = create_dataloader_v1(\n",
    "    raw_text, batch_size=1, max_length=4,stride=1,shuffle=False\n",
    ")\n",
    "data_iter = iter(dataloader)\n",
    "first_batch = next(data_iter)\n",
    "print(first_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([[ 367, 2885, 1464, 1807]]), tensor([[2885, 1464, 1807, 3619]])]\n"
     ]
    }
   ],
   "source": [
    "second_batch = next(data_iter)\n",
    "print(second_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inputs:\n",
      " tensor([[   40,   367,  2885,  1464],\n",
      "        [  367,  2885,  1464,  1807],\n",
      "        [ 2885,  1464,  1807,  3619],\n",
      "        [ 1464,  1807,  3619,   402],\n",
      "        [ 1807,  3619,   402,   271],\n",
      "        [ 3619,   402,   271, 10899],\n",
      "        [  402,   271, 10899,  2138],\n",
      "        [  271, 10899,  2138,   257]])\n",
      "\n",
      "Targets:\n",
      " tensor([[  367,  2885,  1464,  1807],\n",
      "        [ 2885,  1464,  1807,  3619],\n",
      "        [ 1464,  1807,  3619,   402],\n",
      "        [ 1807,  3619,   402,   271],\n",
      "        [ 3619,   402,   271, 10899],\n",
      "        [  402,   271, 10899,  2138],\n",
      "        [  271, 10899,  2138,   257],\n",
      "        [10899,  2138,   257,  7026]])\n"
     ]
    }
   ],
   "source": [
    "dataloader = create_dataloader_v1(raw_text,batch_size=8,max_length=4,stride=1,shuffle=False)\n",
    "\n",
    "data_iter = iter(dataloader)\n",
    "inputs, targets = next(data_iter)\n",
    "print(\"Inputs:\\n\", inputs)\n",
    "print(\"\\nTargets:\\n\", targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = torch.tensor([2,3,5,1])\n",
    "\n",
    "vocab_size = 6\n",
    "output_dim = 3\n",
    "\n",
    "torch.manual_seed(123)\n",
    "embedding_layer = torch.nn.Embedding(vocab_size,output_dim)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[ 0.3374, -0.1778, -0.1690],\n",
      "        [ 0.9178,  1.5810,  1.3010],\n",
      "        [ 1.2753, -0.2010, -0.1606],\n",
      "        [-0.4015,  0.9666, -1.1481],\n",
      "        [-1.1589,  0.3255, -0.6315],\n",
      "        [-2.8400, -0.7849, -1.4096]], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "print(embedding_layer.weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.4015,  0.9666, -1.1481]], grad_fn=<EmbeddingBackward0>)\n"
     ]
    }
   ],
   "source": [
    "print(embedding_layer(torch.tensor([3])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([4])"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor([4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.9178, 1.5810, 1.3010]], grad_fn=<EmbeddingBackward0>)\n",
      "tensor([[ 1.2753, -0.2010, -0.1606]], grad_fn=<EmbeddingBackward0>)\n",
      "tensor([[-0.4015,  0.9666, -1.1481]], grad_fn=<EmbeddingBackward0>)\n"
     ]
    }
   ],
   "source": [
    "print(embedding_layer(torch.tensor([1])))\n",
    "print(embedding_layer(torch.tensor([2])))\n",
    "print(embedding_layer(torch.tensor([3])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.2753, -0.2010, -0.1606],\n",
      "        [-0.4015,  0.9666, -1.1481],\n",
      "        [-2.8400, -0.7849, -1.4096],\n",
      "        [ 0.9178,  1.5810,  1.3010]], grad_fn=<EmbeddingBackward0>)\n"
     ]
    }
   ],
   "source": [
    "print(embedding_layer(input_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token ids:\n",
      " tensor([[   40,   367,  2885,  1464],\n",
      "        [ 1807,  3619,   402,   271],\n",
      "        [10899,  2138,   257,  7026],\n",
      "        [15632,   438,  2016,   257],\n",
      "        [  922,  5891,  1576,   438],\n",
      "        [  568,   340,   373,   645],\n",
      "        [ 1049,  5975,   284,   502],\n",
      "        [  284,  3285,   326,    11]])\n",
      "\n",
      "INputs shape:\n",
      " torch.Size([8, 4])\n"
     ]
    }
   ],
   "source": [
    "max_length = 4\n",
    "dataloader = create_dataloader_v1(\n",
    "    raw_text, batch_size=8, max_length=max_length,\n",
    "    stride=max_length, shuffle=False\n",
    ")\n",
    "data_iter = iter(dataloader)\n",
    "inputs, targets = next(data_iter)\n",
    "\n",
    "print(\"Token ids:\\n\", inputs)\n",
    "print(\"\\nINputs shape:\\n\", inputs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([8, 4, 256])\n"
     ]
    }
   ],
   "source": [
    "vocab_size = 50257\n",
    "output_dim = 256\n",
    "token_embedding_layer = torch.nn.Embedding(vocab_size,output_dim)\n",
    "\n",
    "token_embedding = token_embedding_layer(inputs)\n",
    "print(token_embedding.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 256])\n"
     ]
    }
   ],
   "source": [
    "context_length = max_length\n",
    "\n",
    "pos_embedding_layer = torch.nn.Embedding(context_length,output_dim)\n",
    "\n",
    "pos_embedding = pos_embedding_layer(torch.arange(max_length))\n",
    "print(pos_embedding.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([8, 4, 256])\n",
      "tensor([[[ 3.1963, -0.9273, -0.6706,  ..., -0.3722,  1.9582,  1.3436],\n",
      "         [ 1.3767, -0.9999, -2.3134,  ..., -0.7727,  2.8536,  0.4558],\n",
      "         [ 0.2037, -0.0458,  0.9695,  ...,  1.3806, -0.4898, -0.3085],\n",
      "         [ 0.9437,  0.4423,  0.9439,  ...,  1.1623, -1.3224,  0.8711]],\n",
      "\n",
      "        [[ 0.9527, -0.4987, -0.9977,  ..., -0.5698,  1.1279, -0.4734],\n",
      "         [ 1.2611, -1.2349,  0.3647,  ...,  1.0208,  0.2363, -0.0126],\n",
      "         [-0.6065,  0.8131,  0.6300,  ...,  2.8813, -0.8168,  2.3148],\n",
      "         [-2.2961, -1.2784,  1.1127,  ...,  1.2108,  0.1352,  0.5271]],\n",
      "\n",
      "        [[ 1.9038, -1.4537,  2.6584,  ...,  0.5149,  1.0256,  3.2186],\n",
      "         [ 2.9345,  0.4008, -0.7399,  ...,  0.7323, -0.0897, -1.4153],\n",
      "         [ 0.1099, -0.7525,  0.4368,  ...,  1.0700, -0.6381,  2.1148],\n",
      "         [-0.1068, -0.1594,  2.7963,  ...,  1.6612, -1.9211, -0.1414]],\n",
      "\n",
      "        ...,\n",
      "\n",
      "        [[-0.9062,  0.1846, -0.3344,  ...,  1.4106,  1.5549,  0.4091],\n",
      "         [ 0.3260, -1.9443,  0.6698,  ...,  1.3864,  0.1763,  0.6049],\n",
      "         [-0.2671, -1.3030,  1.8981,  ...,  3.0627, -0.4543,  2.6285],\n",
      "         [-2.0541, -1.4030,  3.5021,  ...,  0.7566, -1.8600,  1.3512]],\n",
      "\n",
      "        [[ 3.4548, -0.3311, -2.5323,  ...,  1.2981,  1.4370,  0.0474],\n",
      "         [ 2.0839, -0.9638, -0.2766,  ...,  2.2872,  2.2334,  1.0031],\n",
      "         [ 0.8929,  0.1006,  0.6659,  ...,  0.2519, -0.8672,  1.7256],\n",
      "         [-2.1924, -0.8753,  2.2379,  ...,  0.6741, -0.5158, -0.8250]],\n",
      "\n",
      "        [[ 3.0954,  0.3144, -0.5450,  ..., -1.4092,  1.2040,  1.9022],\n",
      "         [ 0.6127, -2.7314, -0.5301,  ...,  0.2990,  0.1086, -0.9790],\n",
      "         [ 1.6071, -0.1774,  1.3220,  ...,  1.1332,  0.3072, -0.1389],\n",
      "         [-0.8949,  0.2133,  2.5795,  ...,  0.7440,  0.8799,  0.7419]]],\n",
      "       grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "input_embeddings = token_embedding+pos_embedding\n",
    "print(input_embeddings.shape)\n",
    "\n",
    "print(input_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "LLMs",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
